from pymongo import MongoClient
from sklearn.model_selection import train_test_split
import numpy as np
import pickle as pkl
from tokenizer import tokenize
from sklearn import preprocessing

# --- Configuration ----------------------------------------------------------
# Local MongoDB; each collection listed in `collist` below is read in turn.
client = MongoClient(port=27017)
db = client['tiktok']
sequence_len = 30      # token ids / embedding rows kept per text
version_num = 'd29'    # suffix appended to every output file name
num_total = 50000      # stop after this many documents pass the filter

# Accumulators filled by the extraction loop; the dicts are keyed by _id.
idlist = []
idlist0 = []    # ids whose labelA_n == 0
idlist1 = []    # ids whose labelA_n == 1
img = {}        # element-wise max over all img_embed entries
yamnet = {}     # video_feature.audio.yamnet as stored
texts = {}      # caption text followed by sticker texts
labels = {}     # [labelA, labelB, labelA_n, labelB_n]
edits = {}      # selected video_feature.editing values
htids = {}      # name of the collection the document came from
c = 0           # number of documents accepted so far

# Alternative collection subsets, kept for reference:
# collist=['happyholidays','RoomTour', 'makeitvogue','wip','haventseen','holidayvibes','ootd','bekind','personalfinance','cozyathome','homecooked', 'theatrekids', 'carsoftiktok', 'yougotthis', 'ImAGhost', 'holidaytiktok', 'mycostume', 'halloweenlook']
# collist=['happyholidays','holidayvibes', 'yougotthis', 'ImAGhost', 'holidaytiktok', 'mycostume', 'halloweenlook','meleaving','yellow','carsoftiktok']
# collist=['wip','bekind','Roomtour', 'theatrekids', 'carsoftiktok','motivationmonday']
# collist=['makeitvogue','wip','haventseen','holidayvibes','ootd','bekind','personalfinance','cozyathome','gamenight','happyholidays','veteransday','Roomtour', 'meleaving', 'yellow', 'theatrekids', 'carsoftiktok', 'yougotthis', 'ImAGhost', 'holidaytiktok','mycostume','halloweenlook','happyhalloween','homecooked']
collist = ['makeitvogue', 'wip', 'haventseen', 'holidayvibes', 'ootd', 'bekind',
           'personalfinance', 'cozyathome', 'gamenight', 'happyholidays',
           'veteransday', 'Roomtour', 'meleaving', 'yellow', 'theatrekids',
           'carsoftiktok', 'yougotthis', 'ImAGhost', 'holidaytiktok',
           'mycostume', 'halloweenlook', 'happyhalloween', 'homecooked']

# --- Extraction: read each collection, filter, cache features by _id --------
# To scan every collection instead of the fixed list, iterate over
# db.list_collection_names() rather than collist.
edit_keys = ('video_len', 'var_sb', 'var_sb_c', 'var_vgg', 'var_vgg_c',
             'var_yamnet', 'sticker_num', 'avg_sticker_length')
for col in collist:
    print(col)
    if c >= num_total:
        break
    cursor = db[col].find(no_cursor_timeout=True)
    try:
        for obj in cursor:
            if c >= num_total:
                break
            vf = obj['video_feature']
            # Accept only documents carrying every feature read below. The
            # checks short-circuit in this order, so later keys are touched
            # only when the earlier checks pass.
            # (Earlier variants additionally required a NaN-free
            # vf['editing']['aesthetics_feature'], or omitted the
            # 'labelA_n' requirement.)
            if not (len(vf['img_embed']) > 0
                    and len(obj['text_feature']['text']) > 0
                    and len(vf['audio']['yamnet']) > 0
                    and len(vf['label']) > 0
                    and len(vf['residual']) > 0
                    and 'labelA_n' in vf['label']):
                continue
            c += 1
            doc_id = obj['_id']
            label = vf['label']
            label_a_n = int(label['labelA_n'])
            if label_a_n == 0:
                idlist0.append(doc_id)
            elif label_a_n == 1:
                idlist1.append(doc_id)

            # Element-wise max over every img_embed entry. Start from a copy of
            # the '0' entry so pooling does not mutate the document's own list.
            frames = vf['img_embed']
            pooled = list(frames['0'])
            for frame in frames.values():
                for i, value in enumerate(frame):
                    if pooled[i] < value:
                        pooled[i] = value
            img[doc_id] = pooled
            yamnet[doc_id] = vf['audio']['yamnet']

            # Caption followed by each sticker text, space-separated.
            tf = obj['text_feature']
            texts[doc_id] = ' '.join([tf['text'], *tf['stickerText']])

            # Alternative label set previously used:
            # [labelA, labelB, residual['residualA'], residual['residualB']]
            labels[doc_id] = [int(label['labelA']),
                              int(label['labelB']),
                              label_a_n,
                              int(label['labelB_n'])]

            htids[doc_id] = col
            editing = vf['editing']
            edits[doc_id] = [editing[key] for key in edit_keys]
            # Optional extension: append normalized aesthetics features, e.g.
            # edits[doc_id].extend(list(preprocessing.normalize(
            #     [editing['aesthetics_feature']])[0]))
    finally:
        # Cursors opened with no_cursor_timeout must be closed explicitly,
        # including when an exception interrupts the scan.
        cursor.close()
# --- Class balancing, train/test split, and vocabulary resources ------------
# Down-sample the larger labelA_n class to the size of the smaller one, pool
# both classes (class 1 first, then class 0), and hold out 20 % for testing.
split_seed = None   # set to an int for a reproducible split
if not idlist0 or not idlist1:
    raise ValueError(
        f'cannot balance classes: {len(idlist0)} documents with labelA_n == 0 '
        f'and {len(idlist1)} with labelA_n == 1')
if len(idlist1) > len(idlist0):
    # idres: class-1 ids dropped by down-sampling.
    idres, idlist = train_test_split(idlist1, test_size=len(idlist0),
                                     random_state=split_seed)
    idlist.extend(idlist0)
elif len(idlist0) > len(idlist1):
    # idres: class-0 ids dropped by down-sampling.
    idres, kept0 = train_test_split(idlist0, test_size=len(idlist1),
                                    random_state=split_seed)
    idlist = list(idlist1) + kept0
else:
    # Already balanced; nothing to drop.
    idres = []
    idlist = list(idlist1) + list(idlist0)
d_train, d_test = train_test_split(idlist, test_size=0.2,
                                   random_state=split_seed)

embedding_matrix = np.load('E:\\data_pi\\embedding_matrix.npy')
embedding_matrix_norm = np.load('E:\\data_pi\\embedding_matrix_norm.npy')
# NOTE(review): pickle executes arbitrary code on load; this assumes the file
# is a locally produced, trusted artifact.
with open('E:\\data_pi\\word_index.pkl', 'rb') as word_index_file:
    word_index = pkl.load(word_index_file)

# --- Feature assembly for the train and test splits -------------------------
def _encode_text(raw_text):
    """Map *raw_text* to fixed-length token-id and embedding sequences.

    Only the first ``sequence_len`` tokens from ``tokenize`` are considered.
    Tokens missing from ``word_index`` are skipped (not replaced). The result
    is right-padded with id 0 and zero vectors up to ``sequence_len``.

    Returns ``(ids, embeds, embeds_norm)``: three lists, each of length
    ``sequence_len``. The embedding entries are rows of ``embedding_matrix``
    and ``embedding_matrix_norm`` indexed by token id.
    """
    ids = [word_index[word] for word in tokenize(raw_text)[:sequence_len]
           if word in word_index]
    embeds = [embedding_matrix[i] for i in ids]
    embeds_norm = [embedding_matrix_norm[i] for i in ids]
    n_pad = sequence_len - len(ids)
    ids.extend([0] * n_pad)
    embeds.extend(np.zeros(embedding_matrix.shape[1]) for _ in range(n_pad))
    embeds_norm.extend(np.zeros(embedding_matrix_norm.shape[1])
                       for _ in range(n_pad))
    return ids, embeds, embeds_norm


def _collect_split(ids, ids_path):
    """Write the id manifest for one split and gather its cached features.

    Writes one "<collection><TAB><_id>" line per id to *ids_path*. Returns
    seven lists aligned with *ids*, in this order:
    - image embeddings
    - yamnet features
    - token ids
    - text embeddings
    - normalized-matrix text embeddings
    - labels
    - editing features
    """
    img_out, yam_out, text_out, emb_out, emb_norm_out, label_out, edit_out = (
        [] for _ in range(7))
    with open(ids_path, 'w', encoding='utf-8', newline='\n') as fout:
        for item in ids:
            # An f-string also handles non-str ids (e.g. a BSON ObjectId),
            # where string concatenation would raise TypeError.
            fout.write(f'{htids[item]}\t{item}\n')
            img_out.append(img[item])
            yam_out.append(yamnet[item])
            token_ids, token_embeds, token_embeds_norm = _encode_text(texts[item])
            text_out.append(token_ids)
            emb_out.append(token_embeds)
            emb_norm_out.append(token_embeds_norm)
            label_out.append(labels[item])
            edit_out.append(edits[item])
    return img_out, yam_out, text_out, emb_out, emb_norm_out, label_out, edit_out


(train_img_emb, train_yamnet, train_text, train_text_embed,
 train_text_embed_norm, train_labels, train_edits) = _collect_split(
    d_train, 'E:\\data_pi\\train_ids_' + version_num + '.txt')
(test_img_emb, test_yamnet, test_text, test_text_embed,
 test_text_embed_norm, test_labels, test_edits) = _collect_split(
    d_test, 'E:\\data_pi\\test_ids_' + version_num + '.txt')


# --- Convert to arrays and save (np.save appends the .npy extension) ---------
# Image, audio and editing features are cast to float32 for BOTH splits so the
# train and test files share a dtype.
train_img_emb = np.asarray(train_img_emb, dtype=np.float32)
test_img_emb = np.asarray(test_img_emb, dtype=np.float32)
train_yamnet = np.asarray(train_yamnet, dtype=np.float32)
test_yamnet = np.asarray(test_yamnet, dtype=np.float32)
train_edits = np.asarray(train_edits, dtype=np.float32)
test_edits = np.asarray(test_edits, dtype=np.float32)

out_dir = 'E:\\data_pi\\'
np.save(out_dir + 'train_image_embed_' + version_num, train_img_emb)
np.save(out_dir + 'test_image_embed_' + version_num, test_img_emb)
np.save(out_dir + 'train_yamnet_embed_' + version_num, train_yamnet)
np.save(out_dir + 'test_yamnet_embed_' + version_num, test_yamnet)
np.save(out_dir + 'train_label_' + version_num, np.array(train_labels))
np.save(out_dir + 'test_label_' + version_num, np.array(test_labels))
np.save(out_dir + 'train_text_' + version_num, train_text)
np.save(out_dir + 'test_text_' + version_num, test_text)
np.save(out_dir + 'train_edit_embed_' + version_num, train_edits)
np.save(out_dir + 'test_edit_embed_' + version_num, test_edits)

# Text embeddings keep the dtype of the embedding matrices.
np.save(out_dir + 'train_text_embed_' + version_num, train_text_embed)
np.save(out_dir + 'train_text_embed_norm_' + version_num, train_text_embed_norm)
np.save(out_dir + 'test_text_embed_' + version_num, test_text_embed)
np.save(out_dir + 'test_text_embed_norm_' + version_num, test_text_embed_norm)

